#include <torchtest/common.h>
#include <torchtest/test_runner.h>

// 测试2: 简单模型
void test_simple_model() {
  auto model = torch::nn::Linear(10, 5);
  auto input = torch::randn({1, 10});

  auto output = model->forward(input);
  print_tensor(output, "Model Output");
}

// Register this file's tests with the test runner.
//
// Returns true so the call can initialize the namespace-scope flag below,
// which makes registration happen during static initialization without any
// explicit call from main(). (A void return cannot initialize a bool, which
// made the flag's definition ill-formed.)
bool register_test2() {
  test_runner::register_test("Simple Model", test_simple_model);
  return true;
}

// Initializing this file-local flag is what triggers registration.
// NOTE(review): this runs during dynamic initialization of non-local
// variables; confirm test_runner's registry is safe to use at that point
// (e.g. a function-local static) to avoid static-initialization-order issues.
static bool test2_registered = register_test2();